{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "o12ifaotXaMg"
      },
      "source": [
        "##### Copyright 2019 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "1GBcPzQeXocT"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "GSwlSaVxXuAE"
      },
      "source": [
        "# Closed Form Matting Energy\n",
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/matting.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/notebooks/matting.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "y3tbz-ntDzni"
      },
      "source": [
        "\n",
        "Matting is an important task in image editing: a foreground object is extracted from an image so that it can be combined with a novel background to produce a new composite image. To achieve a plausible result, the foreground needs to be carefully extracted, preserving all of its thin structures, before being composited over the new background. In image matting, the input image $I$ is assumed to be a linear combination of a foreground image $F$ and a background image $B$. For a pixel $j$ of $I$, the color can therefore be expressed as $I_j = \\alpha_j F_j + (1-\\alpha_j)B_j$,\n",
        "where $\\alpha_j$ is the foreground opacity of pixel $j$. The opacity image made of all the $\\alpha_j$ values is called a matte.\n",
        "<div align=\"center\">\n",
        "<img src=\"https://github.com/frcs/alternative-matting-laplacian/raw/master/GT04.png\" width=\"283\" height=\"200\" />\n",
        "<img src=\"https://github.com/frcs/alternative-matting-laplacian/raw/master/alpha0-GT04.png\" width=\"283\" height=\"200\" />\n",
        "</div>\n",
        "\n",
        "Using a trimap (white for foreground, black for background, and gray for unknown pixels)\n",
        "<div align=\"center\">\n",
        "<img src=\"https://github.com/frcs/alternative-matting-laplacian/raw/master/trimap-GT04.png\" width=\"283\" height=\"200\" />\n",
        "</div>\n",
        "\n",
        "or a set of scribbles (user strokes), an optimization problem can be formulated to retrieve the unknown pixel opacities. This colab demonstrates how to use the image matting loss implemented in TensorFlow Graphics to precisely segment objects out of images so that they can be pasted on top of new backgrounds. This matting loss is derived from the paper \"A Closed Form Solution to Natural Image Matting\" by Levin et al. The \"tensorized\" formulation of the loss was inspired by \"Deep-Energy: Unsupervised Training of Deep Neural Networks\" by Golts et al."
      ]
    },
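    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a small worked example of the compositing equation above (the numbers are made up for illustration and are not taken from the dataset): suppose a pixel $j$ has opacity $\\alpha_j = 0.25$, a pure red foreground $F_j = (1, 0, 0)$, and a pure blue background $B_j = (0, 0, 1)$. Then\n",
        "\n",
        "$$I_j = 0.25 \\cdot (1, 0, 0) + 0.75 \\cdot (0, 0, 1) = (0.25, 0, 0.75).$$\n",
        "\n",
        "Matting is the inverse problem: only $I_j$ is observed, and $\\alpha_j$ has to be estimated, which is why additional constraints such as a trimap are needed."
      ]
    },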
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "9iWqeS9qYBAJ"
      },
      "source": [
        "## Setup \u0026 Imports\n",
        "If TensorFlow Graphics is not installed on your system, the following cell can install the TensorFlow Graphics package for you."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "X_v1AoLCYFp-"
      },
      "outputs": [],
      "source": [
        "!pip install tensorflow_graphics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "VSoQOLZXYGnL"
      },
      "source": [
        "Now that TensorFlow Graphics is installed, let's import everything needed to run the demos contained in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "-JnFAN7Ndzi7"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import tensorflow as tf\n",
        "from tensorflow_graphics.image import matting\n",
        "from tqdm import tqdm\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ni8fbo1OaHsj"
      },
      "source": [
        "## Import the image and trimap\n",
        "Download the input image, its trimap, and the ground-truth matte from [alphamatting.com](http://alphamatting.com/)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "vBg2BCKPZCQ0"
      },
      "outputs": [],
      "source": [
        "# Download dataset from alphamatting.com\n",
        "!rm -rf input_training_lowres\n",
        "!rm -rf trimap_training_lowres\n",
        "!rm -rf gt_training_lowres\n",
        "\n",
        "!wget -q http://www.alphamatting.com/datasets/zip/input_training_lowres.zip\n",
        "!wget -q http://www.alphamatting.com/datasets/zip/trimap_training_lowres.zip\n",
        "!wget -q http://www.alphamatting.com/datasets/zip/gt_training_lowres.zip\n",
        "\n",
        "!unzip -q input_training_lowres.zip -d input_training_lowres\n",
        "!unzip -q trimap_training_lowres.zip -d trimap_training_lowres\n",
        "!unzip -q gt_training_lowres.zip -d gt_training_lowres"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "E9uDtWFfE5lP"
      },
      "outputs": [],
      "source": [
        "# Read and decode images\n",
        "source = tf.io.read_file('input_training_lowres/GT07.png')\n",
        "source = tf.cast(tf.io.decode_png(source), tf.float64) / 255.0\n",
        "source = tf.expand_dims(source, axis=0)\n",
        "trimap = tf.io.read_file('trimap_training_lowres/Trimap1/GT07.png')\n",
        "trimap = tf.cast(tf.io.decode_png(trimap), tf.float64) / 255.0\n",
        "trimap = tf.reduce_mean(trimap, axis=-1, keepdims=True)\n",
        "trimap = tf.expand_dims(trimap, axis=0)\n",
        "gt_matte = tf.io.read_file('gt_training_lowres/GT07.png')\n",
        "gt_matte = tf.cast(tf.io.decode_png(gt_matte), tf.float64) / 255.0\n",
        "gt_matte = tf.reduce_mean(gt_matte, axis=-1, keepdims=True)\n",
        "gt_matte = tf.expand_dims(gt_matte, axis=0)\n",
        "\n",
        "# Resize images to improve performance\n",
        "source = tf.image.resize(\n",
        "    source,\n",
        "    tf.shape(source)[1:3] // 2,\n",
        "    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "trimap = tf.image.resize(\n",
        "    trimap,\n",
        "    tf.shape(trimap)[1:3] // 2,\n",
        "    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "gt_matte = tf.image.resize(\n",
        "    gt_matte,\n",
        "    tf.shape(gt_matte)[1:3] // 2,\n",
        "    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "\n",
        "# Show images\n",
        "figure = plt.figure(figsize=(22, 18))\n",
        "axes = figure.add_subplot(1, 3, 1)\n",
        "axes.grid(False)\n",
        "axes.set_title('Input image', fontsize=14)\n",
        "_= plt.imshow(source[0, ...].numpy())\n",
        "axes = figure.add_subplot(1, 3, 2)\n",
        "axes.grid(False)\n",
        "axes.set_title('Input trimap', fontsize=14)\n",
        "_= plt.imshow(trimap[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)\n",
        "axes = figure.add_subplot(1, 3, 3)\n",
        "axes.grid(False)\n",
        "axes.set_title('GT matte', fontsize=14)\n",
        "_= plt.imshow(gt_matte[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "gA7WTJI1Y0qZ"
      },
      "source": [
        "## Extract the foreground and background constraints from the trimap image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "ZzLFoL5_YSgX"
      },
      "outputs": [],
      "source": [
        "# Extract the foreground and background constraints from the trimap image\n",
        "foreground = tf.cast(tf.equal(trimap, 1.0), tf.float64)\n",
        "background = tf.cast(tf.equal(trimap, 0.0), tf.float64)\n",
        "\n",
        "# Show foreground and background constraints\n",
        "figure = plt.figure(figsize=(22, 18))\n",
        "axes = figure.add_subplot(1, 2, 1)\n",
        "axes.grid(False)\n",
        "axes.set_title('Foreground constraints', fontsize=14)\n",
        "_= plt.imshow(foreground[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)\n",
        "axes = figure.add_subplot(1, 2, 2)\n",
        "axes.grid(False)\n",
        "axes.set_title('Background constraints', fontsize=14)\n",
        "_= plt.imshow(background[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "INHsXVYHZdIi"
      },
      "source": [
        "## Set up \u0026 run the optimization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Ya-cECcCajJ2"
      },
      "source": [
        "Set up the matting loss function using TensorFlow Graphics and run the Adam optimizer for 400 iterations."
      ]
    },
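    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a sketch of the objective minimized in the next cell, written out from its code (the symbols $N$, $f_j$, $b_j$, $E_{\\text{matting}}$ and $E_{\\text{total}}$ are notation introduced here for illustration; they are not names from the library): let $f_j$ and $b_j$ be the foreground and background constraint masks extracted from the trimap, $L$ the matting Laplacian returned by `matting.build_matrices`, and $N$ the number of pixels. The loss is then\n",
        "\n",
        "$$E_{\\text{total}}(\\alpha) = \\frac{100}{N} \\sum_{j=1}^{N} (f_j + b_j)(\\alpha_j - f_j)^2 + E_{\\text{matting}}(\\alpha, L),$$\n",
        "\n",
        "where $E_{\\text{matting}}$ denotes the value returned by `matting.loss`. The first term reduces to $(\\alpha_j - 1)^2$ on known foreground pixels, to $\\alpha_j^2$ on known background pixels, and to $0$ on unknown pixels, so only the matting term drives the opacities in the gray region of the trimap. The weight of 100 is the one used in the code below."
      ]
    },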
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "BYSuaxkRaaxz"
      },
      "outputs": [],
      "source": [
        "# Initialize the matte with random values\n",
        "matte_shape = tf.concat((tf.shape(source)[:-1], (1,)), axis=-1)\n",
        "matte = tf.Variable(\n",
        "    tf.random.uniform(\n",
        "        shape=matte_shape, minval=0.0, maxval=1.0, dtype=tf.float64))\n",
        "# Create the closed form matting Laplacian\n",
        "laplacian, _ = matting.build_matrices(source)\n",
        "\n",
        "# Function computing the loss and applying the gradient\n",
        "@tf.function\n",
        "def optimize(optimizer):\n",
        "  with tf.GradientTape() as tape:\n",
        "    tape.watch(matte)\n",
        "    # Compute a loss enforcing the trimap constraints\n",
        "    constraints = tf.reduce_mean((foreground + background) *\n",
        "                                 tf.math.squared_difference(matte, foreground))\n",
        "    # Compute the matting loss\n",
        "    smoothness = matting.loss(matte, laplacian)\n",
        "    # Sum up the constraint and matting losses\n",
        "    total_loss = 100 * constraints + smoothness\n",
        "  # Compute and apply the gradient to the matte\n",
        "  gradient = tape.gradient(total_loss, [matte])\n",
        "  optimizer.apply_gradients(zip(gradient, (matte,)))\n",
        "\n",
        "# Run the Adam optimizer for 400 iterations\n",
        "optimizer = tf.optimizers.Adam(learning_rate=1.0)\n",
        "nb_iterations = 400\n",
        "for it in tqdm(range(nb_iterations)):\n",
        "  optimize(optimizer)\n",
        "\n",
        "# Clip the matte value between 0 and 1\n",
        "matte = tf.clip_by_value(matte, 0.0, 1.0)\n",
        "\n",
        "# Display the results\n",
        "figure = plt.figure(figsize=(22, 18))\n",
        "axes = figure.add_subplot(1, 3, 1)\n",
        "axes.grid(False)\n",
        "axes.set_title('Input image', fontsize=14)\n",
        "plt.imshow(source[0, ...].numpy())\n",
        "axes = figure.add_subplot(1, 3, 2)\n",
        "axes.grid(False)\n",
        "axes.set_title('Input trimap', fontsize=14)\n",
        "_= plt.imshow(trimap[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)\n",
        "axes = figure.add_subplot(1, 3, 3)\n",
        "axes.grid(False)\n",
        "axes.set_title('Matte', fontsize=14)\n",
        "_= plt.imshow(matte[0, ..., 0].numpy(), cmap='gray', vmin=0, vmax=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Zk2k9-3TW3Q0"
      },
      "source": [
        "### Compositing\n",
        "Let's now composite our extracted object on top of a new background!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {},
        "colab_type": "code",
        "id": "rgXBjBNgW2rM"
      },
      "outputs": [],
      "source": [
        "!wget -q https://p2.piqsels.com/preview/861/934/460/concrete-texture-background-backdrop.jpg\n",
        "# Use a new name so the background constraint mask defined earlier is not overwritten\n",
        "new_background = tf.io.read_file('concrete-texture-background-backdrop.jpg')\n",
        "new_background = tf.cast(tf.io.decode_jpeg(new_background), tf.float64) / 255.0\n",
        "new_background = tf.expand_dims(new_background, axis=0)\n",
        "\n",
        "# Resize the new background to match the size of the source image\n",
        "new_background = tf.image.resize(\n",
        "    new_background,\n",
        "    tf.shape(source)[1:3],\n",
        "    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
        "\n",
        "# Composite the foreground over black and over the new background\n",
        "inpainted_black = matte * source\n",
        "inpainted_concrete = matte * source + (1.0 - matte) * new_background\n",
        "\n",
        "# Display the results\n",
        "figure = plt.figure(figsize=(22, 18))\n",
        "axes = figure.add_subplot(1, 2, 1)\n",
        "axes.grid(False)\n",
        "axes.set_title('Inpainted black', fontsize=14)\n",
        "_= plt.imshow(inpainted_black[0, ...].numpy())\n",
        "axes = figure.add_subplot(1, 2, 2)\n",
        "axes.grid(False)\n",
        "axes.set_title('Inpainted concrete', fontsize=14)\n",
        "_= plt.imshow(inpainted_concrete[0, ...].numpy())"
      ]
    },
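    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To see how the composite computed above relates to the matting model, here is a short derivation (a sketch; $C_j$ for the composite color and $B'_j$ for the new background color are symbols introduced here for illustration). The previous cell computes\n",
        "\n",
        "$$C_j = \\alpha_j I_j + (1-\\alpha_j) B'_j.$$\n",
        "\n",
        "Substituting the model $I_j = \\alpha_j F_j + (1-\\alpha_j) B_j$ gives\n",
        "\n",
        "$$C_j = \\alpha_j^2 F_j + \\alpha_j (1-\\alpha_j) B_j + (1-\\alpha_j) B'_j,$$\n",
        "\n",
        "whereas the ideal composite is $\\alpha_j F_j + (1-\\alpha_j) B'_j$. The difference between the two is $\\alpha_j (1-\\alpha_j)(B_j - F_j)$: it vanishes where $\\alpha_j$ is 0 or 1 and is largest for semi-transparent pixels, where some of the original background color remains visible."
      ]
    },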
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "IDduD_dXXTVu"
      },
      "source": [
        "Note that the composite is approximate because the source image $I_j$ was used in place of the real foreground $F_j = \\frac{I_j - (1-\\alpha_j)B_j}{\\alpha_j}$. Recovering $F_j$ also requires an estimate of the original background color $B_j$."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Matting.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
